#include "common.h"

#include "jszt_cnn.h"
#include "cnntools.h"

/* Global CNN context, private to this module. Allocated by init_CNN(). */
Golad_CNN * gcnn = NULL;

/* NOTE(review): neither of these is referenced in this file; confirm whether
 * another translation unit uses them before removing. */
pthread_mutex_t mutex = PTHREAD_MUTEX_INITIALIZER;
pthread_cond_t has_producer = PTHREAD_COND_INITIALIZER;

/*
 * Trace hook for the convolution hot path.
 * To enable tracing, replace the definition below with:
 *     #define  DEBUG_RUN_CNNING   LOG_DEBUG
 * When disabled, the macro swallows its arguments entirely so they are
 * neither evaluated nor left behind as a discarded comma expression.
 */
#define  DEBUG_RUN_CNNING(...)   ((void)0)

/*
 * Convolve one horizontal stripe of the input with a square kernel.
 *
 * cnn      stripe descriptor: stream (input bytes, row stride = width),
 *          dstream (output base), cnnker (ckerlen x ckerlen kernel bytes),
 *          width, height and ckerlen.
 * thindex  index of the calling worker; selects where in dstream the
 *          stripe's output starts (thindex * gcnn->wind_rows output rows in).
 * total    index of the last worker.
 *
 * Only window positions that fit horizontally are produced, so every output
 * row holds (width - ckerlen + 1) values. The last worker also stops where
 * the window no longer fits vertically; every other worker emits `height`
 * rows and therefore reads up to ckerlen - 1 input rows beyond its own
 * `height` rows.
 *
 * Each output value is the sum of products over the window, clamped to 0xFF;
 * accumulation stops as soon as the running sum exceeds 255.
 *
 * Always returns JS_OK.
 */
JSZT_CODE cnn_running(CNN_ONE *cnn, uint8 thindex, uint8 total) {
    DEBUG_RUN_CNNING("thindex: %d, stream: %p, dstream: %p, cnnker: %p, width: %d, height: %d, ckerlen: %d",  \
            thindex, cnn->stream, cnn->dstream, cnn->cnnker, cnn->width, cnn->height, cnn->ckerlen);
    // Sanity trace of the first/last bytes of the first input and output rows.
    DEBUG_RUN_CNNING("stream[0]: %0x, stream[width-1]: %0x, dstream[0]: %d, dstream[width-1]: %0x", \
            cnn->stream[0], cnn->stream[cnn->width - 1], cnn->dstream[0], cnn->dstream[cnn->width-1]);

    const uint8 *src    = cnn->stream;
    const uint8 *kernel = cnn->cnnker;
    const int in_w      = cnn->width;
    const int in_h      = cnn->height;
    const int ksize     = cnn->ckerlen;

    // Horizontal border: number of window positions per row.
    const int out_w     = in_w - ksize + 1;
    // Vertical border: only the last stripe has to stop early.
    const int out_rows  = (thindex == total) ? in_h - ksize + 1 : in_h;

    // The border handling above determines the output row stride.
    uint8 *dst          = cnn->dstream + (thindex*gcnn->wind_rows*out_w);

    int row, col, ky, kx;

    for(row = 0; row < out_rows; row++) {
        uint8 *out_row = dst + row*out_w;

        for(col = 0; col < out_w; col++) {
            int acc = 0;
            int saturated = 0;

            for(ky = 0; ky < ksize && !saturated; ky++) {
                const uint8 *in_row  = src + (row + ky)*in_w + col;
                const uint8 *ker_row = kernel + ky*ksize;

                for(kx = 0; kx < ksize; kx++) {
                    acc += in_row[kx] * ker_row[kx];
                    if(acc > 255) {
                        // Clamp and skip the rest of the window.
                        acc = 0xFF;
                        saturated = 1;
                        break;
                    }
                }
            }

            out_row[col] = acc;
        }
    }

    return JS_OK;
}
/*
 * Sum the squares of the first `times` entries of `ckerlen`.
 *
 * start_CNN() uses the result as the byte offset of the kernel for pass
 * number `times` inside the concatenated kernel buffer.
 *
 * times    number of leading entries to accumulate; <= 0 yields 0.
 * ckerlen  array of kernel side lengths; only the first `times` entries are read.
 *
 * Returns the accumulated offset.
 */
int get_cker_offset(int times, uint8 *ckerlen) {
    int offset = 0;
    while(times-- > 0) {
        offset += (*ckerlen) * (*ckerlen);
        ckerlen++;
    }
    return offset;
}
/*
 * Worker thread entry point.
 *
 * Para  pointer to this worker's TH_STATUS slot (thread_index must already
 *       be set by the creator).
 *
 * The worker spins until the global state is JSCNN_START and its own slot is
 * TH_READY, convolves its stripe with cnn_running(), marks the slot TH_OVER,
 * and repeats. It leaves the loop when the global state becomes JSCNN___END.
 *
 * NOTE(review): the status fields are polled with plain reads and no lock;
 * whether they are declared volatile/atomic is not visible from this file.
 *
 * Always returns NULL.
 */
void *start_CNN(void *Para) {
    CNN_ONE cnn_one;
    CNN_PARA * cPara        =   NULL;
    TH_STATUS *th_status    =   (TH_STATUS *)Para;
    uint8 index             =   th_status->thread_index;  // index of this worker
    uint8 total             =   gcnn->thread_total - 1;   // index of the last worker

    memset(&cnn_one, 0, sizeof(CNN_ONE));

#ifdef BIND_KERNEL
    /**
     * Pin this worker to a CPU.
    */
    int cur_cpu = gcnn->max_cpus - index;                     // bind starting from the last CPU
    cpu_set_t cpuset        =   { 0 };                        // CPU bit mask
    CPU_ZERO(&cpuset);
    CPU_SET(cur_cpu, &cpuset);

    if(0 != sched_setaffinity(0, sizeof(cpu_set_t), &cpuset)) {
        LOG_WARN("sched_setaffinity, thread index: %d, binding CPU: %d, %s", index, cur_cpu, strerror(errno));
    }
    
    CPU_ZERO(&cpuset);
    if(0 != sched_getaffinity(0, sizeof(cpu_set_t), &cpuset)) {
        LOG_WARN("sched_getaffinity,thread index: %d, binding CPU: %d, %s", index, cur_cpu, strerror(errno));
    }
    
    if(!CPU_ISSET(cur_cpu, &cpuset)) {
        LOG_WARN("Thread binding failure, thread index: %d, binding CPU: %d, %s", index, cur_cpu, strerror(errno));
    }

    LOG_INFO("Thread ids %d, Run success !", index + 1);
#endif // BIND_KERNEL 

    while(1) {
        // Spin until there is work for this worker, or shutdown is requested.
        while(1) {
            if(gcnn->status == JSCNN_START && th_status->status == TH_READY){
                break;
            }
            if(gcnn->status == JSCNN___END) {
                goto BACK;
            }
        };

        cPara   = gcnn->cPara;
        // Stripe `index` starts wind_rows * index rows into the input, and
        // cnn_running() walks the input with a row stride of `width`, so the
        // byte offset must be computed with width (not height).
        cnn_one.stream  = cPara->stream   + (index*cPara->width*gcnn->wind_rows);
        cnn_one.dstream = cPara->dstream;
        cnn_one.cnnker  = cPara->cnnker   + get_cker_offset(gcnn->cnn_times, cPara->ckrlen);
        cnn_one.ckerlen = *(cPara->ckrlen + gcnn->cnn_times);
        // The last worker also takes the rows left over by the integer split.
        cnn_one.height  = index == total ? gcnn->wind_rows + gcnn->last_rows : gcnn->wind_rows;
        cnn_one.width   = cPara->width;

        cnn_running(&cnn_one, index, total);
        th_status->total_times++;
        DEBUG_RUN_CNNING("%d running ...", index);
        // Publish completion last: the manager polls this field.
        th_status->status = TH_OVER;
    }
BACK:
    DEBUG_RUN_CNNING("destory thread index: %d", index);
    return NULL;
}

/*
 * Create one worker thread per TH_STATUS slot.
 *
 * Slots are walked in array order while thread indices are handed out in
 * descending order (thread_total - 1 down to 0).
 *
 * Returns JS_OK on success. If a thread cannot be created, the global state
 * is set to JSCNN___END (so workers already running leave their loop) and
 * JS_ERROR_CREATE__TH is returned.
 */
JSZT_CODE cnn_setReady() {
    int ret = 0;
    int thread_num = gcnn->thread_total;
    TH_STATUS *thread_status = gcnn->tstatus;
    
    while(thread_num-- > 0) {
        thread_status->thread_index = thread_num;
        ret = pthread_create(&thread_status->thread_id, NULL, start_CNN, (void *)thread_status);
        if(ret != 0) {
            // pthread_create reports the error through its return value, not errno.
            LOG_ERROR("create thread failed; %s", strerror(ret));
            gcnn->status = JSCNN___END;
            return JS_ERROR_CREATE__TH;
        }
        thread_status++;
    }
    return JS_OK;
}

/*
 * Request shutdown and wait for every worker thread.
 *
 * Sets the global state to JSCNN___END, then joins the thread stored in each
 * of the thread_total TH_STATUS slots.
 *
 * Returns JS_OK if every join succeeded, otherwise JS_ERROR_RELEASE_TH
 * (all slots are still joined after a failure).
 */
JSZT_CODE cnn_setEnd() {
    JSZT_CODE result = JS_OK;
    TH_STATUS *slots;
    int count;
    int i;

    gcnn->status = JSCNN___END;

    count = gcnn->thread_total;
    slots = gcnn->tstatus;

    for(i = 0; i < count; i++) {
        if(pthread_join(slots[i].thread_id, NULL) != 0) {
            result = JS_ERROR_RELEASE_TH;
            LOG_ERROR("Error releasing thread");
        }
    }

    return result;
}

/*
 * Allocate and initialise the global CNN context.
 *
 * used_cpus  number of worker threads to use; must be in 1..max_cpus.
 * max_cpus   number of CPUs available on the machine.
 *
 * Returns JS_OK on success, JS_ERROR_CONFIG for an invalid CPU count, and
 * JS_ERROR_NULL if any allocation fails (nothing stays allocated and gcnn is
 * left NULL in that case).
 *
 * NOTE(review): the TH_STATUS slots and the remaining Golad_CNN fields are
 * not initialised here; whether js_malloc zero-fills is not visible from
 * this file.
 */
JSZT_CODE init_CNN(int used_cpus, int max_cpus) {
    TH_STATUS *thread_status = NULL;

    if(used_cpus <= 0) {
        LOG_ERROR("cur mode used CPU nums is not %d", used_cpus);
        return JS_ERROR_CONFIG;
    }

    if(used_cpus > max_cpus) {
        LOG_ERROR("Local max CPU nums is %d, but used CPUS nums is %d", max_cpus, used_cpus);
        return JS_ERROR_CONFIG;
    }

    thread_status = (TH_STATUS *)js_malloc(sizeof(TH_STATUS)*used_cpus);
    if(NULL == thread_status) {
        // sizeof yields size_t; cast so the argument matches the %d conversion.
        LOG_ERROR("thread_status is malloc failed, malloc size is %d ", (int)(sizeof(TH_STATUS)*used_cpus));
        return JS_ERROR_NULL;
    }

    gcnn = (Golad_CNN *)js_malloc(sizeof(Golad_CNN));
    if(NULL == gcnn) {
        free(thread_status);
        LOG_ERROR("gcnn is malloc failed, malloc size is %d ", (int)sizeof(Golad_CNN));
        return JS_ERROR_NULL;
    }

    // manage_CNN() refuses to run without this scratch buffer, so treat a
    // failed allocation as an initialisation failure instead of deferring it.
    gcnn->cnnTemp = (uint8 *)js_malloc(CNNTEMP_LEN);
    if(NULL == gcnn->cnnTemp) {
        LOG_ERROR("cnnTemp is malloc failed");
        free(gcnn);
        gcnn = NULL;
        free(thread_status);
        return JS_ERROR_NULL;
    }

    LOG_INFO("malloc gcnn: %p, th_status: %p", gcnn, thread_status);
    gcnn->tstatus = thread_status;
    gcnn->status = JSCNN_AWAIT;
    gcnn->thread_total = used_cpus;
    gcnn->max_cpus = max_cpus - 1;      // stored as the highest CPU index
    
    return JS_OK;
}

/*
 * Run one convolution pass (kernel number gcnn->cnn_times) across all workers.
 *
 * Splits the input rows of gcnn->cPara evenly over gcnn->thread_total workers
 * (wind_rows each, remainder stored in last_rows), records the output size in
 * cPara->dwidth / cPara->dheight, releases the workers and spins until none
 * of them is still TH_READY.
 *
 * Returns JS_ERROR_BORDER when a stripe would have fewer rows than the kernel
 * side length, otherwise JS_OK.
 *
 * NOTE(review): the handshake uses plain reads/writes of the status fields
 * and no lock or condition variable in this function; whether those fields
 * are volatile/atomic is not visible from this file.
 */
JSZT_CODE run_cnn_one() {
    int cnt = 0;
    CNN_PARA *cPara = gcnn->cPara;
    TH_STATUS *th_status = gcnn->tstatus;
    
    // Rows per worker (integer division); the remainder goes to last_rows below.
    gcnn->wind_rows = (int)(cPara->height/gcnn->thread_total);
    if(gcnn->wind_rows < cPara->ckrlen[gcnn->cnn_times]) {
        LOG_ERROR("wind_rows is too small, but cnnker border is too lager");
        return JS_ERROR_BORDER;
    }
    gcnn->last_rows = cPara->height - gcnn->wind_rows*gcnn->thread_total;

    // Output dimensions: only window positions that fit inside the input.
    cPara->dwidth  = cPara->width   - cPara->ckrlen[gcnn->cnn_times] + 1;
    cPara->dheight = cPara->height  - cPara->ckrlen[gcnn->cnn_times] + 1; 

    // 1. Move every worker that is TH_OVER to TH_READY.
    //    Workers in any other state are left untouched.
    cnt = gcnn->thread_total;
    while(cnt-->0) {
        if((th_status+cnt)->status == TH_OVER) {
            (th_status+cnt)->status = TH_READY;
        }
    }

    // 2. Switch the global state to JSCNN_START so the workers begin computing.
    gcnn->status = JSCNN_START;

    // 3. Poll the workers; restart the scan while any of them is still TH_READY.
RETEST:
    // Do not sleep here (e.g. usleep(10*1000)): it lowers throughput.
    cnt = gcnn->thread_total;
    while(cnt-- > 0) {
        if((th_status+cnt)->status == TH_READY) {
            goto RETEST;
        }
    }

    // 4. This pass is finished; put the global state back to JSCNN_AWAIT
    //    until the next pass is requested.
    gcnn->status = JSCNN_AWAIT;
    return JS_OK;
}

/*
 * Run the whole chain of convolution passes described by cPara.
 *
 * cPara->ckrlen is read as a 0-terminated list of kernel side lengths; one
 * pass is run per non-zero entry. From the second pass on, the previous
 * output (dstream, dwidth x dheight) is copied into cPara->stream and becomes
 * the input. If cPara->dstream is NULL, the context's scratch buffer
 * (gcnn->cnnTemp) is used as the output buffer.
 *
 * Side effects on *cPara: dcnt is incremented, dwidth/dheight describe the
 * last pass that was started, dstream may be filled in, and stream is
 * overwritten by intermediate results when more than one pass runs.
 * width/height are restored to their entry values before returning.
 *
 * NOTE(review): cPara is declared const but is written through a cast; the
 * caller must pass a modifiable object.
 *
 * Returns JS_ERROR_NULL if the context, cPara or the scratch buffer is
 * missing, the error code of the first failing pass, or JS_OK.
 */
JSZT_CODE manage_CNN(const CNN_PARA *cPara) {
    CNN_PARA *cnnPara = (CNN_PARA *)cPara;
    JSZT_CODE ret = JS_OK;
    int preWidth  = 0;
    int preHeight = 0;

    // Validate before touching gcnn: it is NULL until init_CNN() succeeds.
    if(NULL == gcnn) {
        LOG_ERROR("gcnn is NULL");
        return JS_ERROR_NULL;
    }

    if(NULL == cPara) {
        LOG_ERROR("cPara is NULL");
        return JS_ERROR_NULL;
    }

    gcnn->cPara = cnnPara;

    if(NULL == gcnn->cnnTemp) {
        LOG_ERROR("cnnTemp is NULL");
        return JS_ERROR_NULL;
    }

    gcnn->cnn_times = 0;       // index of the current convolution pass
    
    if(NULL == cnnPara->dstream) {
        cnnPara->dstream = gcnn->cnnTemp;
    }

    preWidth  = cnnPara->width;
    preHeight = cnnPara->height;
    cnnPara->dcnt++;

    while(0 != cnnPara->ckrlen[gcnn->cnn_times]) {
        if(0 != gcnn->cnn_times) {
            // Feed the previous pass's output back in as this pass's input.
            cnnPara->width   = cnnPara->dwidth;
            cnnPara->height  = cnnPara->dheight;
            memcpy(cnnPara->stream, cnnPara->dstream, cnnPara->width*cnnPara->height);
        }
        cnnPara->dwidth  = cnnPara->width  - cnnPara->ckrlen[gcnn->cnn_times] + 1;
        cnnPara->dheight = cnnPara->height - cnnPara->ckrlen[gcnn->cnn_times] + 1;

        ret = run_cnn_one();
        if(JS_OK != ret) {
            // Stop the chain: later passes would consume a stale dstream.
            LOG_ERROR("run_cnn_one failed, cnn times: %d", (int)gcnn->cnn_times);
            break;
        }
        gcnn->cnn_times++;
    }

    // Restore the caller's input dimensions on every exit path of the loop.
    cnnPara->width  = preWidth;
    cnnPara->height = preHeight;

    return ret;
}